/**
 * An unbalanced binary search tree of distinct {@code int} values.
 *
 * <p>Ordering invariant: for every node, all values in its left subtree are
 * smaller than the node's value and all values in its right subtree are larger.
 * Duplicates are rejected by {@link #insert(int)}.
 */
public class BinarySearchTree {
    /** A single tree node holding a value and links to its two children. */
    static class TreeNode {
        int val;
        TreeNode left;
        TreeNode right;

        public TreeNode(int val) {
            this.val = val;
        }
    }

    /** Root of the tree; {@code null} while the tree is empty. */
    public TreeNode root;

    /**
     * Looks up a value in the tree.
     *
     * @param val the value to look for
     * @return the node holding {@code val}, or {@code null} if it is not present
     */
    public TreeNode search(int val) {
        TreeNode cur = root;
        while (cur != null) {
            if (val < cur.val) {
                // Smaller values live in the left subtree.
                cur = cur.left;
            } else if (val > cur.val) {
                cur = cur.right;
            } else {
                return cur;
            }
        }
        return null;
    }

    /**
     * Inserts a value as a new leaf.
     *
     * @param val the value to insert
     * @return {@code true} if the value was added, {@code false} if it was already present
     */
    public boolean insert(int val) {
        if (root == null) {
            root = new TreeNode(val);
            return true;
        }
        TreeNode parent = null;
        TreeNode cur = root;
        while (cur != null) {
            parent = cur;
            if (cur.val < val) {
                cur = cur.right;
            } else if (cur.val > val) {
                cur = cur.left;
            } else {
                return false;
            }
        }
        // Pick the side by comparing values: cur is null here, and both of the
        // parent's child slots may be null, so cur's position cannot be used.
        TreeNode node = new TreeNode(val);
        if (parent.val < val) {
            parent.right = node;
        } else {
            parent.left = node;
        }
        return true;
    }

    /**
     * Removes a value from the tree; does nothing if the value is not present.
     *
     * @param val the value to remove
     */
    public void remove(int val) {
        TreeNode parent = null;
        TreeNode cur = root;
        while (cur != null) {
            if (cur.val > val) {
                // Advance parent only when descending, so that on a match
                // it still refers to cur's actual parent.
                parent = cur;
                cur = cur.left;
            } else if (cur.val < val) {
                parent = cur;
                cur = cur.right;
            } else {
                removeNode(parent, cur);
                return;
            }
        }
    }

    /**
     * Unlinks {@code cur} from the tree.
     *
     * <p>A node with at most one child is replaced by that child (or by
     * {@code null}). A node with two children takes over the value of the
     * leftmost node of its right subtree, and that node is unlinked instead.
     *
     * @param parent the parent of {@code cur}; {@code null} when {@code cur} is the root
     * @param cur    the node to remove; must not be {@code null}
     */
    private void removeNode(TreeNode parent, TreeNode cur) {
        if (cur.left == null) {
            replaceChild(parent, cur, cur.right);
        } else if (cur.right == null) {
            replaceChild(parent, cur, cur.left);
        } else {
            // Two children: find the smallest node of the right subtree.
            TreeNode targetParent = cur;
            TreeNode target = cur.right;
            while (target.left != null) {
                targetParent = target;
                target = target.left;
            }
            cur.val = target.val;
            // target has no left child, so splice in its right child.
            // It is targetParent's right child only when the loop never ran
            // (targetParent == cur).
            if (target == targetParent.left) {
                targetParent.left = target.right;
            } else {
                targetParent.right = target.right;
            }
        }
    }

    /** Makes {@code replacement} occupy the slot currently held by {@code cur}. */
    private void replaceChild(TreeNode parent, TreeNode cur, TreeNode replacement) {
        if (cur == root) {
            root = replacement;
        } else if (parent.left == cur) {
            parent.left = replacement;
        } else {
            parent.right = replacement;
        }
    }
}
